{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Text to Multiclass Explanation: Language Modeling Example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates how to explain the top-k next words predicted by a language model. We use the pretrained gpt2 model provided by Hugging Face (https://huggingface.co/gpt2) to predict the top-k next words, treat those k words as k separate classes, and compute an explanation for each one. The explanations show how much each word in the input contributes to the likelihood of each of the top-k next words." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "\n", "import shap" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load model and tokenizer" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\", use_fast=True)\n", "model = AutoModelForCausalLM.from_pretrained(\"gpt2\").cuda()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we wrap the model in TopKLM, which extracts the log odds of the top-k next words. We also create a Text masker with mask_token = \"...\" and collapse_mask_token = True. The masker uses this token to infill text when the inputs are perturbed."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "wrapped_model = shap.models.TopKLM(model, tokenizer, k=100)\n", "masker = shap.maskers.Text(tokenizer, mask_token=\"...\", collapse_mask_token=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we set the initial text for which we want the gpt2 model to predict the next word." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "s = [\"In a shocking finding, scientists discovered a herd of unicorns living in a\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create explainer object" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "explainer = shap.Explainer(wrapped_model, masker)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compute SHAP values" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "shap_values = explainer(s)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize the SHAP values across the input sentence for the top-k next words" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The plot below lists the top-k next words predicted by gpt2 under \"Output Text\". Hover over any of these tokens to see which words in the input sentence drive the prediction of that output word." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
[SHAP text plot for input [0]. Outputs (top-k next words): cave, forest, small, desert, tiny, \", remote, zoo, tree, field, house, nest, tropical, lake, large, mountain, farm, group, wild, very, single, barn, jungle, new, valley, world, garden, herd, grass, natural, park, swamp, laboratory, nearby, well, rural, pond, dark, wood, subter, room, lab, cage, huge, New, water, colony, massive, common, state, deep, home, man, mine, human, rock, region, box, river, part, hollow, c, hole, vast, village, different, virtual, city, strange, greenhouse, frozen, shallow, semi, flat, patch, mysterious, local, giant, sub, barren, special, mountainous, mud, cemetery, pod, hive, newly, closed, community, California, place, flooded, prehistoric, sw, high, z, hot, far, 1, pasture. Force plot for the output cave: base value -12.5736; f_cave(inputs) = -2.91492. Input token contributions for cave, in sentence order: In (0.115); a (0.138); shocking (-0.2); finding (-0.236); , (-0.304); scientists (-0.153); discovered (-0.304); a (-0.176); herd (0.165); of (0.474); unic (0.463); orns (0.741); living (1.266); in (2.87); a (4.799)]